---
title: Cross-validation
date: 2019-01-28T13:30:00-06:00  # Schedule page publish date.
draft: false
type: docs
bibliography: [../../static/bib/sources.bib]
csl: [../../static/bib/apa.csl]
link-citations: true
menu:
  notes:
    parent: Resampling methods
    weight: 1
---
```r
library(tidyverse)
library(tidymodels)
library(magrittr)
library(here)
library(rcfss)

set.seed(1234)
theme_set(theme_minimal())
```

Resampling methods are essential to test and evaluate statistical models. Because you likely do not have the resources or capabilities to repeatedly sample from your population of interest, you can instead repeatedly draw from your original sample to obtain additional information about your model. For instance, you could repeatedly draw samples from your data, estimate a linear regression model on each sample, and then examine how the estimated model differs across each sample. This allows you to assess the variability and stability of your model in a way not possible if you can only fit the model once.
There are two major types of resampling methods we will consider: cross-validation and the bootstrap. These notes focus on cross-validation, in three forms: the validation set approach, leave-one-out cross-validation (LOOCV), and \(K\)-fold cross-validation.
In most modeling situations, we can immediately partition the dataset into a training set and a test set. The training set will be used for model construction, and the test set will be used to evaluate the performance of the final model. This point is critical: while you can reuse the training set many times to build different statistical models, you can only use the test set once. If you reuse it, you introduce data leakage into your modeling process and no longer have unbiased estimates of the test error. This is why collaborative platforms such as Kaggle hold back a portion of the dataset in their competitions. You can use the training set to build the strongest performing model, but you cannot tune your model based on the test error because you do not have access to it.
Even accounting for the training/test set split, one issue with using the same data to both fit and evaluate our model is that we will bias our model towards fitting the data that we have. We may fit our function to create the results we expect or desire, rather than the “true” function. Instead, we can further split our training set into distinct training and validation sets. The training set can be used repeatedly to train different models. We then use the validation set to evaluate the model’s performance, generating metrics such as the mean squared error (MSE) or the error rate. Unlike the test set, we are permitted to use the validation set multiple times. The important thing is that we do not use the validation set to train or fit the model, only to evaluate its performance after it has been fit.
Here we will examine the relationship between horsepower and car mileage in the `Auto` dataset (found in `library(ISLR)`):

```r
library(ISLR)

Auto <- as_tibble(Auto)
Auto
```

```
## # A tibble: 392 x 9
##      mpg cylinders displacement horsepower weight acceleration  year origin
##    <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
##  1    18         8          307        130   3504         12      70      1
##  2    15         8          350        165   3693         11.5    70      1
##  3    18         8          318        150   3436         11      70      1
##  4    16         8          304        150   3433         12      70      1
##  5    17         8          302        140   3449         10.5    70      1
##  6    15         8          429        198   4341         10      70      1
##  7    14         8          454        220   4354          9      70      1
##  8    14         8          440        215   4312          8.5    70      1
##  9    14         8          455        225   4425         10      70      1
## 10    15         8          390        190   3850          8.5    70      1
## # … with 382 more rows, and 1 more variable: name <fct>
```

```r
ggplot(Auto, aes(horsepower, mpg)) +
  geom_point()
```
The relationship does not appear to be strictly linear:

```r
ggplot(Auto, aes(horsepower, mpg)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE)
```
Perhaps by adding quadratic terms to the linear regression we could improve overall model fit. To evaluate the model, we will split the data into a training set and validation set, estimate a series of higher-order models, and calculate a test statistic summarizing the accuracy of the estimated mpg. To calculate the accuracy of the model, we will use mean squared error (MSE), defined as
\[MSE = \frac{1}{N} \sum_{i = 1}^{N}{(y_i - \hat{f}(x_i))^2}\]
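To make this definition concrete, here is a minimal sketch with made-up observed and predicted values (the `y` and `y_hat` vectors are invented for illustration; below we will let `mse()` from `yardstick` do this for us):

```r
# MSE by hand: the average squared difference between observed and predicted
mse_manual <- function(y, y_hat) {
  mean((y - y_hat)^2)
}

y     <- c(18, 15, 18, 16)  # observed values (made up)
y_hat <- c(17, 16, 19, 14)  # predicted values (made up)

mse_manual(y, y_hat)
# [1] 1.75
```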
For this task, first we use rsample::initial_split() to create training and validation sets (using a 50/50 split), then estimate a linear regression model without any quadratic terms.
Two notes on this code:

- `set.seed()` in the beginning - whenever you are writing a script that involves randomization (here, random subsetting of the data), always set the seed at the beginning of the script. This ensures the results can be reproduced precisely.
- `glm()` function rather than `lm()` - if you don't change the `family` parameter, the results of `lm()` and `glm()` are exactly the same.

```r
set.seed(1234)

auto_split <- initial_split(data = Auto, prop = 0.5)
auto_train <- training(auto_split)
auto_test <- testing(auto_split)

auto_lm <- glm(mpg ~ horsepower, data = auto_train)
```
```r
summary(auto_lm)
```

```
## 
## Call:
## glm(formula = mpg ~ horsepower, data = auto_train)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -13.7105   -3.4442   -0.5342    2.6256   15.1015  
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 40.057910   1.054798   37.98   <2e-16 ***
## horsepower  -0.157604   0.009402  -16.76   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for gaussian family taken to be 24.80151)
## 
##     Null deviance: 11780.6  on 195  degrees of freedom
## Residual deviance:  4811.5  on 194  degrees of freedom
## AIC: 1189.6
## 
## Number of Fisher Scoring iterations: 2
```
To estimate the MSE for a single partition (i.e. for a training or validation set):

1. Use `broom::augment()` to generate predicted values for the data set
2. Use the `mse()` function from `yardstick` to calculate the mean squared error based on the observed and predicted values

For the training set, this would look like:
```r
(train_mse <- augment(auto_lm, newdata = auto_train) %>%
   mse(truth = mpg, estimate = .fitted))
```

```
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mse     standard        24.5
```
Note the special use of the `%$%` pipe operator from the `magrittr` package. This allows us to directly access columns from the data frame entering the pipe. This is especially useful for integrating non-tidy functions into a tidy operation.
For the validation set:
```r
(test_mse <- augment(auto_lm, newdata = auto_test) %>%
   mse(truth = mpg, estimate = .fitted))
```

```
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mse     standard        23.4
```
For a strictly linear model, the MSE for the validation set is 23.38. How does this compare to a quadratic model? We can use the poly() function in conjunction with a map() iteration to estimate the MSE for a series of models with higher-order polynomial terms:
```r
# visualize each model
ggplot(Auto, aes(horsepower, mpg)) +
  geom_point(alpha = .1) +
  geom_smooth(aes(color = "1"),
              method = "glm",
              formula = y ~ poly(x, degree = 1),
              se = FALSE) +
  geom_smooth(aes(color = "2"),
              method = "glm",
              formula = y ~ poly(x, degree = 2),
              se = FALSE) +
  geom_smooth(aes(color = "3"),
              method = "glm",
              formula = y ~ poly(x, degree = 3),
              se = FALSE) +
  geom_smooth(aes(color = "4"),
              method = "glm",
              formula = y ~ poly(x, degree = 4),
              se = FALSE) +
  geom_smooth(aes(color = "5"),
              method = "glm",
              formula = y ~ poly(x, degree = 5),
              se = FALSE) +
  scale_color_brewer(type = "qual", palette = "Dark2") +
  labs(x = "Horsepower",
       y = "MPG",
       color = "Highest-order\npolynomial")
```
```r
# function to estimate model using training set and generate fit statistics
# using the test set
poly_results <- function(train, test, i) {
  # Fit the model to the training set
  mod <- glm(mpg ~ poly(horsepower, i, raw = TRUE), data = train)

  # `augment` will save the predictions with the test data set
  res <- augment(mod, newdata = test) %>%
    mse(truth = mpg, estimate = .fitted)

  # Return the test data set with the additional columns
  res
}

# function to return MSE for a specific higher-order polynomial term
poly_mse <- function(i, train, test){
  poly_results(train, test, i) %$%
    mean(.estimate)
}

cv_mse <- tibble(terms = seq(from = 1, to = 5),
                 mse_test = map_dbl(terms, poly_mse, auto_train, auto_test))

ggplot(cv_mse, aes(terms, mse_test)) +
  geom_line() +
  labs(title = "Comparing quadratic linear models",
       subtitle = "Using validation set",
       x = "Highest-order polynomial",
       y = "Mean Squared Error")
```
Based on the MSE for the validation set, a polynomial model with a quadratic term (\(\text{horsepower}^2\)) produces a lower average error than the standard model. A higher-order term such as a fifth-order polynomial leads to an even larger reduction, though it increases the complexity of interpreting the model.
Recall our efforts to predict passenger survival during the sinking of the Titanic.
```r
library(titanic)
titanic <- as_tibble(titanic_train) %>%
  mutate(Survived = factor(Survived))

titanic %>%
  head() %>%
  knitr::kable()
```

| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22 | 1 | 0 | A/5 21171 | 7.2500 | S | |
| 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Thayer) | female | 38 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26 | 0 | 0 | STON/O2. 3101282 | 7.9250 | S | |
| 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 5 | 0 | 3 | Allen, Mr. William Henry | male | 35 | 0 | 0 | 373450 | 8.0500 | S | |
| 6 | 0 | 3 | Moran, Mr. James | male | NA | 0 | 0 | 330877 | 8.4583 | Q |
```r
survive_age_woman_x <- glm(Survived ~ Age * Sex, data = titanic,
                           family = binomial)
summary(survive_age_woman_x)
```

```
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = titanic)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9401  -0.7136  -0.5883   0.7626   2.2455  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)   
## (Intercept)  0.59380    0.31032   1.913  0.05569 . 
## Age          0.01970    0.01057   1.863  0.06240 . 
## Sexmale     -1.31775    0.40842  -3.226  0.00125 **
## Age:Sexmale -0.04112    0.01355  -3.034  0.00241 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 964.52  on 713  degrees of freedom
## Residual deviance: 740.40  on 710  degrees of freedom
##   (177 observations deleted due to missingness)
## AIC: 748.4
## 
## Number of Fisher Scoring iterations: 4
```
We can use the same validation set approach to evaluate the model’s accuracy. For classification models, instead of using MSE we examine the error rate. That is, of all the predictions generated for the test set, what percentage of predictions are incorrect? The goal is to minimize this value as much as possible (ideally, until we make no errors and our error rate is \(0\)).
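As a minimal sketch of the metric itself, the error rate is just the proportion of mismatched predictions (the `truth` and `pred` vectors below are made up for illustration):

```r
truth <- factor(c(1, 0, 1, 1, 0))  # made-up observed outcomes
pred  <- factor(c(1, 0, 0, 1, 1))  # made-up predictions

# error rate: proportion of predictions that disagree with the truth
mean(pred != truth)
# [1] 0.4
```

Equivalently, the error rate is one minus the accuracy, which is how we compute it from `yardstick::accuracy()` below.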
```r
# function to convert log-odds to probabilities
logit2prob <- function(x){
  exp(x) / (1 + exp(x))
}

# split the data into training and validation sets
titanic_split <- initial_split(data = titanic, prop = 0.5)
```
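As a sanity check, `logit2prob()` should agree with base R's `plogis()`, which computes the same inverse-logit transformation:

```r
logit2prob <- function(x){
  exp(x) / (1 + exp(x))
}

# log-odds of 0 corresponds to an even 50/50 chance
logit2prob(0)
# [1] 0.5

# agrees with the built-in inverse logit
all.equal(logit2prob(c(-2, 0, 2)), plogis(c(-2, 0, 2)))
# [1] TRUE
```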
```r
# fit model to training data
train_model <- glm(Survived ~ Age * Sex, data = training(titanic_split),
                   family = binomial)
summary(train_model)
```

```
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = training(titanic_split))
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.1511  -0.7346  -0.5386   0.7339   2.2216  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)    
## (Intercept)  0.17464    0.41877   0.417 0.676659    
## Age          0.03570    0.01525   2.342 0.019198 *  
## Sexmale     -0.59608    0.56604  -1.053 0.292313    
## Age:Sexmale -0.06833    0.01994  -3.426 0.000612 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 478.37  on 353  degrees of freedom
## Residual deviance: 361.88  on 350  degrees of freedom
##   (92 observations deleted due to missingness)
## AIC: 369.88
## 
## Number of Fisher Scoring iterations: 4
```

```r
# calculate predictions using validation set
x_test_accuracy <- augment(train_model, newdata = testing(titanic_split)) %>%
  as_tibble() %>%
  mutate(.prob = logit2prob(.fitted),
         .pred = factor(round(.prob)))

# calculate test error rate
accuracy(x_test_accuracy, truth = Survived, estimate = .pred)
```

```
## # A tibble: 1 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.783
```
This interactive model generates an error rate of 21.7%. We could compare this error rate to alternative classification models, either other logistic regression models (using different formulas) or a tree-based method.
There are two main problems with validation sets:
Validation estimates of the test error rates can be highly variable depending on which observations are sampled into the training and test sets. See what happens if we repeat the sampling, estimation, and validation procedure for the Auto data set:
```r
mse_variable <- function(Auto){
  auto_split <- initial_split(Auto, prop = 0.5)
  auto_train <- training(auto_split)
  auto_test <- testing(auto_split)

  cv_mse <- tibble(terms = seq(from = 1, to = 5),
                   mse_test = map_dbl(terms, poly_mse, auto_train, auto_test))

  return(cv_mse)
}

rerun(10, mse_variable(Auto)) %>%
  bind_rows(.id = "id") %>%
  ggplot(aes(terms, mse_test, color = id)) +
  geom_line() +
  labs(title = "Variability of MSE estimates",
       subtitle = "Using the validation set approach",
       x = "Degree of Polynomial",
       y = "Mean Squared Error") +
  theme(legend.position = "none")
```
Depending on the specific training/test split, our MSE varies by up to 5.
If you don’t have a large data set, you’ll have to dramatically shrink the size of your training set. Most statistical learning methods perform better with more observations - if you don’t have enough data in the training set, you might overestimate the error rate in the test set.
An alternative method is leave-one-out cross validation (LOOCV). Like with the validation set approach, you split the data into two parts. However the difference is that you only remove one observation for the test set, and keep all remaining observations in the training set. The statistical learning method is fit on the \(N-1\) training set. You then use the held-out observation to calculate the \(MSE = (y_1 - \hat{y}_1)^2\) which should be an unbiased estimator of the test error. Because this MSE is highly dependent on which observation is held out, we repeat this process for every single observation in the data set. Mathematically, this looks like:
\[CV_{(N)} = \frac{1}{N} \sum_{i = 1}^{N}{MSE_i}\]
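The procedure can be sketched directly with base R on a small made-up data frame (the `df` values are invented): fit \(N\) models, each omitting one observation, and average the resulting squared errors:

```r
df <- data.frame(x = 1:10, y = c(2, 4, 5, 4, 6, 7, 8, 9, 9, 12))  # made-up data

# LOOCV by hand: one model per held-out observation
loocv_err <- vapply(seq_len(nrow(df)), function(i) {
  fit <- lm(y ~ x, data = df[-i, ])        # train on the N - 1 rows
  pred <- predict(fit, newdata = df[i, ])  # predict the held-out row
  as.numeric((df$y[i] - pred)^2)           # squared error for that row
}, numeric(1))

mean(loocv_err)  # the CV_(N) estimate
```

For least-squares fits there is also a well-known closed-form shortcut, `mean((residuals(fit) / (1 - hatvalues(fit)))^2)` computed on the full-data model, which avoids refitting entirely.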
This method produces estimates of the error rate that are approximately unbiased and are non-varying for a given dataset, unlike the validation set approach where the MSE estimate is highly dependent on the sampling process for training/test sets. However, it can have high variance because the \(N\) “training sets” are so similar to one another. LOOCV is also highly flexible and works with any kind of predictive modeling.
Of course the downside is that this method is computationally difficult. You have to estimate \(N\) different models - if you have a large \(N\) or each individual model takes a long time to compute, you may be stuck waiting a long time for the computer to finish its calculations.
We can use the loo_cv() function in the rsample library to compute the LOOCV of any linear or logistic regression model. It takes a single argument: the data frame being cross-validated. For the Auto dataset, this looks like:
```r
loocv_data <- loo_cv(Auto)
loocv_data
```

```
## # Leave-one-out cross-validation 
## # A tibble: 392 x 2
##    splits          id        
##    <list>          <chr>     
##  1 <split [391/1]> Resample1 
##  2 <split [391/1]> Resample2 
##  3 <split [391/1]> Resample3 
##  4 <split [391/1]> Resample4 
##  5 <split [391/1]> Resample5 
##  6 <split [391/1]> Resample6 
##  7 <split [391/1]> Resample7 
##  8 <split [391/1]> Resample8 
##  9 <split [391/1]> Resample9 
## 10 <split [391/1]> Resample10
## # … with 382 more rows
```
Each element of loocv_data$splits is an object of class rsplit. This is essentially an efficient container for storing both the analysis data (i.e. the training data set) and the assessment data (i.e. the validation data set). If we print the contents of a single rsplit object:
```r
first_resample <- loocv_data$splits[[1]]
first_resample
```

```
## <391/1/392>
```
This tells us there are 391 observations in the analysis set, 1 observation in the assessment set, and the original data set contained 392 observations. To extract the analysis/assessment sets, use analysis() or assessment() respectively:
```r
analysis(first_resample)
```

```
## # A tibble: 391 x 9
##      mpg cylinders displacement horsepower weight acceleration  year origin
##    <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
##  1    18         8          307        130   3504         12      70      1
##  2    15         8          350        165   3693         11.5    70      1
##  3    18         8          318        150   3436         11      70      1
##  4    16         8          304        150   3433         12      70      1
##  5    17         8          302        140   3449         10.5    70      1
##  6    15         8          429        198   4341         10      70      1
##  7    14         8          454        220   4354          9      70      1
##  8    14         8          440        215   4312          8.5    70      1
##  9    14         8          455        225   4425         10      70      1
## 10    15         8          390        190   3850          8.5    70      1
## # … with 381 more rows, and 1 more variable: name <fct>
```

```r
assessment(first_resample)
```

```
## # A tibble: 1 x 9
##     mpg cylinders displacement horsepower weight acceleration  year origin
##   <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
## 1    25         4          113         95   2228           14    71      3
## # … with 1 more variable: name <fct>
```
Given this new `loocv_data` data frame, we write a function that will, for each resample:

1. Obtain the analysis data set (i.e. the \(N-1\) training observations)
2. Fit a linear regression model to the analysis set
3. Predict the held-out observation and calculate the MSE using `augment()` from the `broom` package

```r
holdout_results <- function(splits) {
  # Fit the model to the N-1 observations
  mod <- glm(mpg ~ horsepower, data = analysis(splits))

  # Save the heldout observation
  holdout <- assessment(splits)

  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = holdout) %>%
    # calculate the metric
    mse(truth = mpg, estimate = .fitted)

  # Return the metrics
  res
}
```

This function works for a single resample:
```r
holdout_results(loocv_data$splits[[1]])
```

```
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mse     standard     0.00355
```
To compute the MSE for each heldout observation (i.e. estimate the test MSE for each of the \(N\) observations), we use the map() function from the purrr package to estimate the model for each training set, then calculate the MSE for each observation in each test set:
```r
loocv_data_poly1 <- loocv_data %>%
  mutate(results = map(splits, holdout_results)) %>%
  unnest(results) %>%
  spread(.metric, .estimate)
loocv_data_poly1
```

```
## # A tibble: 392 x 4
##    splits          id         .estimator      mse
##    <list>          <chr>      <chr>         <dbl>
##  1 <split [391/1]> Resample1  standard    0.00355
##  2 <split [391/1]> Resample2  standard    1.25   
##  3 <split [391/1]> Resample3  standard   19.6    
##  4 <split [391/1]> Resample4  standard    2.42   
##  5 <split [391/1]> Resample5  standard   16.7    
##  6 <split [391/1]> Resample6  standard   97.0    
##  7 <split [391/1]> Resample7  standard   57.7    
##  8 <split [391/1]> Resample8  standard    1.77   
##  9 <split [391/1]> Resample9  standard   15.3    
## 10 <split [391/1]> Resample10 standard   24.2    
## # … with 382 more rows
```
Now we can compute the overall LOOCV MSE for the data set by calculating the mean of the mse column:
```r
loocv_data_poly1 %>%
  summarize(mse = mean(mse))
```

```
## # A tibble: 1 x 1
##     mse
##   <dbl>
## 1  24.2
```
We can also use this method to compare the optimal number of polynomial terms as before.
```r
# modified function to estimate model with varying highest order polynomial
holdout_results <- function(splits, i) {
  # Fit the model to the N-1 observations
  mod <- glm(mpg ~ poly(horsepower, i), data = analysis(splits))

  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = assessment(splits)) %>%
    # calculate the metric
    mse(truth = mpg, estimate = .fitted)

  # Return the assessment data set with the additional columns
  res
}

# function to return MSE for a specific higher-order polynomial term
poly_mse <- function(i, loocv_data){
  loocv_mod <- loocv_data %>%
    mutate(results = map(splits, holdout_results, i)) %>%
    unnest(results) %>%
    spread(.metric, .estimate)

  mean(loocv_mod$mse)
}

cv_mse <- tibble(terms = seq(from = 1, to = 5),
                 mse_loocv = map_dbl(terms, poly_mse, loocv_data))
cv_mse
```

```
## # A tibble: 5 x 2
##   terms mse_loocv
##   <int>     <dbl>
## 1     1      24.2
## 2     2      19.2
## 3     3      19.3
## 4     4      19.4
## 5     5      19.0
```

```r
ggplot(cv_mse, aes(terms, mse_loocv)) +
  geom_line() +
  labs(title = "Comparing quadratic linear models",
       subtitle = "Using LOOCV",
       x = "Highest-order polynomial",
       y = "Mean Squared Error")
```
And we arrive at a similar conclusion. There may be a very marginal advantage to adding a fifth-order polynomial, but not substantial enough to justify the additional complexity over a mere second-order polynomial.
Let’s verify the error rate of our interactive terms model for the Titanic data set:
```r
# function to generate assessment statistics for titanic model
holdout_results <- function(splits) {
  # Fit the model to the N-1 observations
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)

  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = assessment(splits)) %>%
    as_tibble() %>%
    mutate(.prob = logit2prob(.fitted),
           .pred = round(.prob))

  # Return the assessment data set with the additional columns
  res
}

titanic_loocv <- loo_cv(titanic) %>%
  mutate(results = map(splits, holdout_results)) %>%
  unnest(results) %>%
  mutate(.pred = factor(.pred)) %>%
  group_by(id) %>%
  accuracy(truth = Survived, estimate = .pred)

1 - mean(titanic_loocv$.estimate, na.rm = TRUE)
```

```
## [1] 0.219888
```
In a classification problem, the LOOCV tells us the average error rate based on our predictions. So here, it tells us that the interactive Age * Sex model has a 22% error rate. This is similar to the validation set result (21.7%).
A less computationally-intensive approach to cross validation is \(K\)-fold cross-validation. Rather than dividing the data into \(N\) groups, one divides the observations into \(K\) groups, or folds, of approximately equal size. The first fold is treated as the validation set, and the model is estimated on the remaining \(K-1\) folds. This process is repeated \(K\) times, with each fold serving as the validation set precisely once. The \(K\)-fold CV estimate is calculated by averaging the MSE values for each fold:
\[CV_{(K)} = \frac{1}{K} \sum_{i = 1}^{K}{MSE_i}\]
As you may have noticed, LOOCV is a special case of \(K\)-fold cross-validation where \(K = N\). More typically researchers will use \(K=5\) or \(K=10\) depending on the size of the data set and the complexity of the statistical model.
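The fold bookkeeping can be sketched in base R on made-up data (below, `vfold_cv()` from `rsample` will handle this for us):

```r
set.seed(1234)
df <- data.frame(x = 1:20, y = rnorm(20))  # made-up data
k <- 5

# randomly assign each observation to one of k roughly equal-sized folds
folds <- sample(rep(seq_len(k), length.out = nrow(df)))

fold_mse <- vapply(seq_len(k), function(j) {
  fit <- lm(y ~ x, data = df[folds != j, ])         # train on the other k - 1 folds
  pred <- predict(fit, newdata = df[folds == j, ])  # predict the held-out fold
  mean((df$y[folds == j] - pred)^2)                 # MSE for fold j
}, numeric(1))

mean(fold_mse)  # the CV_(K) estimate
```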
Let’s go back to the Auto data set. Instead of LOOCV, let’s use 10-fold CV to compare the different polynomial models.
```r
# modified function to estimate model with varying highest order polynomial
holdout_results <- function(splits, i) {
  # Fit the model to the training set
  mod <- glm(mpg ~ poly(horsepower, i), data = analysis(splits))

  # Save the heldout observations
  holdout <- assessment(splits)

  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = holdout) %>%
    # calculate the metric
    mse(truth = mpg, estimate = .fitted)

  # Return the assessment data set with the additional columns
  res
}

# function to return MSE for a specific higher-order polynomial term
poly_mse <- function(i, vfold_data){
  vfold_mod <- vfold_data %>%
    mutate(results = map(splits, holdout_results, i)) %>%
    unnest(results) %>%
    spread(.metric, .estimate)

  mean(vfold_mod$mse)
}

# split Auto into 10 folds
auto_cv10 <- vfold_cv(data = Auto, v = 10)

cv_mse <- tibble(terms = seq(from = 1, to = 5),
                 mse_vfold = map_dbl(terms, poly_mse, auto_cv10))
cv_mse
```

```
## # A tibble: 5 x 2
##   terms mse_vfold
##   <int>     <dbl>
## 1     1      24.2
## 2     2      19.3
## 3     3      19.4
## 4     4      19.6
## 5     5      19.3
```
How do these results compare to the LOOCV values?
```r
auto_loocv <- loo_cv(Auto)

tibble(terms = seq(from = 1, to = 5),
       `10-fold` = map_dbl(terms, poly_mse, auto_cv10),
       LOOCV = map_dbl(terms, poly_mse, auto_loocv)
       ) %>%
  gather(method, MSE, -terms) %>%
  ggplot(aes(terms, MSE, color = method)) +
  geom_line() +
  labs(title = "MSE estimates",
       x = "Degree of Polynomial",
       y = "Mean Squared Error",
       color = "CV Method")
```
Pretty much the same results.
We can verify the speed difference by profiling both approaches with the `profvis` package:

```r
library(profvis)

profvis({
  tibble(terms = seq(from = 1, to = 5),
         mse_vfold = map_dbl(terms, poly_mse, auto_loocv))
})

profvis({
  tibble(terms = seq(from = 1, to = 5),
         mse_vfold = map_dbl(terms, poly_mse, auto_cv10))
})
```

On my machine, 10-fold CV was about 40 times faster than LOOCV. Again, estimating \(K=10\) models is going to be much easier than estimating \(K=392\) models.
You’ve gotten the idea by now, but let’s do it one more time on our interactive Titanic model.
```r
# function to generate assessment statistics for titanic model
holdout_results <- function(splits) {
  # Fit the model to the training set
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)

  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = assessment(splits)) %>%
    as_tibble() %>%
    mutate(.prob = logit2prob(.fitted),
           .pred = round(.prob))

  # Return the assessment data set with the additional columns
  res
}

titanic_cv10 <- vfold_cv(data = titanic, v = 10) %>%
  mutate(results = map(splits, holdout_results)) %>%
  unnest(results) %>%
  mutate(.pred = factor(.pred)) %>%
  group_by(id) %>%
  accuracy(truth = Survived, estimate = .pred)

1 - mean(titanic_cv10$.estimate, na.rm = TRUE)
```

```
## [1] 0.2200643
```
Not a large difference from the LOOCV approach, but it takes much less time to compute.
Ignoring the computational efficiency concerns, why not always estimate cross-validation with \(K=N\)? Or more generally, what is the optimal value for \(K\)? The unsatisfying answer is: it depends.

More precisely, it depends on how we wish to handle the bias-variance tradeoff. LOOCV is a low-bias, high-variance method. That is, it provides unbiased estimates of the test error since each training set contains \(N-1\) observations. This is almost as many observations as contained in the full data set. \(K\)-fold CV for \(K=5\) or \(10\) leads to an intermediate amount of bias, since each training set contains \(\frac{(K-1)N}{K}\) observations. This is fewer than LOOCV, but more than a standard validation set approach with just a single split into training and validation sets. If all we care about is bias, we should prefer LOOCV.
However, recall the contributors to a model’s error:
\[\text{Error} = \text{Irreducible Error} + \text{Bias}^2 + \text{Variance}\]
We also should be concerned with the variance of the model. LOOCV has a higher variance than \(K\)-fold with \(K < N\). When we perform LOOCV, we are averaging the outputs of \(N\) fitted models which are trained on nearly entirely identical sets of observations. The results will be highly correlated with one another. In contrast, \(K\)-fold CV with \(K < N\) averages the output of \(K\) fitted models that are less correlated with one another, since the data sets are not as identical. Since the mean of many highly correlated quantities has higher variance than the mean of quantities with less correlation, the test error estimate from LOOCV has higher variance than the test error estimate from \(K\)-fold CV.
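This variance point can be made precise with a standard identity: the variance of the mean of \(n\) estimates, each with variance \(\sigma^2\) and pairwise correlation \(\rho\), is \(\frac{\sigma^2}{n}(1 + (n - 1)\rho)\). A stylized sketch of that identity (the specific \(n\), \(\sigma^2\), and \(\rho\) values below are invented for illustration, not estimated from any CV procedure):

```r
# variance of the mean of n equally correlated estimates,
# each with variance sigma2 and pairwise correlation rho
var_of_mean <- function(n, sigma2, rho) {
  (sigma2 / n) * (1 + (n - 1) * rho)
}

# 10 nearly uncorrelated fold estimates (k-fold-like)
var_of_mean(n = 10, sigma2 = 1, rho = 0.1)
# [1] 0.19

# 392 highly correlated estimates (LOOCV-like): larger, despite the larger n
var_of_mean(n = 392, sigma2 = 1, rho = 0.9)
# [1] 0.9002551
```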
Given these considerations, a typical approach uses \(K=5\) or \(K=10\). Empirical research (see Breiman and Spector (1992), Kohavi and others (1995)) shows that cross-validation with these values of \(K\) suffers neither excessively high bias nor excessively high variance.
To ensure each set is approximately similar to one another in every important aspect, we use random sampling without replacement to partition the data set. Alternative approaches include:
Cross-validation with time series data - Bergmeir and Benítez (2012) evaluate multiple forms of cross-validation methods for time series models. For some types of models, standard cross-validation techniques can be employed without bias. In other situations, a standard approach is to partition the data temporally. For instance, if you have 10 years of observations for a given unit you may reserve the first 8 years for the training set and the last 2 years for the test set. Other forms of cross-validation use rolling test sets whereby one generates a series of test sets, each containing a single observation. The training set consists only of observations that occurred prior to the observation that forms the test set.
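The temporal-partition idea can be sketched with base R indices, assuming 10 annual observations where each test set is the single year after its training window (the `rsample` package also provides `rolling_origin()` for this pattern):

```r
n_years <- 10
initial_window <- 8  # reserve the first 8 years for the initial training set

# rolling splits: train only on observations prior to each test year
splits <- lapply(seq(initial_window + 1, n_years), function(t) {
  list(train = seq_len(t - 1),  # all years before the test year
       test  = t)               # the single test year
})

str(splits)
# two splits: train on years 1-8 / test year 9, then train on years 1-9 / test year 10
```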
```r
devtools::session_info()
```

```
## ─ Session info ──────────────────────────────────────────────────────────
##  setting  value                       
##  version  R version 3.5.2 (2018-12-20)
##  os       macOS Mojave 10.14.2        
##  system   x86_64, darwin15.6.0        
##  ui       X11                         
##  language (EN)                        
##  collate  en_US.UTF-8                 
##  ctype    en_US.UTF-8                 
##  tz       America/Chicago             
##  date     2019-01-25                  
```
## ─ Packages ──────────────────────────────────────────────────────────────
## package * version date lib
## assertthat 0.2.0 2017-04-11 [2]
## backports 1.1.3 2018-12-14 [2]
## base64enc 0.1-3 2015-07-28 [2]
## bayesplot 1.6.0 2018-08-02 [2]
## bindr 0.1.1 2018-03-13 [2]
## bindrcpp * 0.2.2 2018-03-29 [1]
## blogdown 0.9.4 2018-11-26 [1]
## bookdown 0.9 2018-12-21 [1]
## broom * 0.5.1 2018-12-05 [2]
## callr 3.1.1 2018-12-21 [2]
## cellranger 1.1.0 2016-07-27 [2]
## class 7.3-15 2019-01-01 [2]
## cli 1.0.1 2018-09-25 [1]
## codetools 0.2-16 2018-12-24 [2]
## colorspace 1.3-2 2016-12-14 [2]
## colourpicker 1.0 2017-09-27 [2]
## crayon 1.3.4 2017-09-16 [2]
## crosstalk 1.0.0 2016-12-21 [2]
## desc 1.2.0 2018-05-01 [2]
## devtools 2.0.1 2018-10-26 [1]
## dials * 0.0.2 2018-12-09 [1]
## digest 0.6.18 2018-10-10 [1]
## dplyr * 0.7.8 2018-11-10 [1]
## DT 0.5 2018-11-05 [2]
## dygraphs 1.1.1.6 2018-07-11 [2]
## evaluate 0.12 2018-10-09 [2]
## forcats * 0.3.0 2018-02-19 [2]
## fs 1.2.6 2018-08-23 [1]
## generics 0.0.2 2018-11-29 [1]
## ggplot2 * 3.1.0 2018-10-25 [1]
## ggridges 0.5.1 2018-09-27 [2]
## glue 1.3.0 2018-07-17 [2]
## gower 0.1.2 2017-02-23 [2]
## gridExtra 2.3 2017-09-09 [2]
## gtable 0.2.0 2016-02-26 [2]
## gtools 3.8.1 2018-06-26 [2]
## haven 2.0.0 2018-11-22 [2]
## here * 0.1 2017-05-28 [2]
## hms 0.4.2 2018-03-10 [2]
## htmltools 0.3.6 2017-04-28 [1]
## htmlwidgets 1.3 2018-09-30 [2]
## httpuv 1.4.5.1 2018-12-18 [2]
## httr 1.4.0 2018-12-11 [2]
## igraph 1.2.2 2018-07-27 [2]
## infer * 0.4.0 2018-11-15 [1]
## inline 0.3.15 2018-05-18 [2]
## ipred 0.9-8 2018-11-05 [1]
## janeaustenr 0.1.5 2017-06-10 [2]
## jsonlite 1.6 2018-12-07 [2]
## knitr 1.21 2018-12-10 [2]
## later 0.7.5 2018-09-18 [2]
## lattice 0.20-38 2018-11-04 [2]
## lava 1.6.4 2018-11-25 [2]
## lazyeval 0.2.1 2017-10-29 [2]
## lme4 1.1-19 2018-11-10 [2]
## loo 2.0.0 2018-04-11 [2]
## lubridate 1.7.4 2018-04-11 [2]
## magrittr * 1.5 2014-11-22 [2]
## markdown 0.9 2018-12-07 [2]
## MASS 7.3-51.1 2018-11-01 [2]
## Matrix 1.2-15 2018-11-01 [2]
## matrixStats 0.54.0 2018-07-23 [2]
## memoise 1.1.0 2017-04-21 [2]
## mime 0.6 2018-10-05 [1]
## miniUI 0.1.1.1 2018-05-18 [2]
## minqa 1.2.4 2014-10-09 [2]
## modelr 0.1.2 2018-05-11 [2]
## munsell 0.5.0 2018-06-12 [2]
## nlme 3.1-137 2018-04-07 [2]
## nloptr 1.2.1 2018-10-03 [2]
## nnet 7.3-12 2016-02-02 [2]
## parsnip * 0.0.1 2018-11-12 [1]
## pillar 1.3.1 2018-12-15 [2]
## pkgbuild 1.0.2 2018-10-16 [1]
## pkgconfig 2.0.2 2018-08-16 [2]
## pkgload 1.0.2 2018-10-29 [1]
## plyr 1.8.4 2016-06-08 [2]
## prettyunits 1.0.2 2015-07-13 [2]
## pROC 1.13.0 2018-09-24 [1]
## processx 3.2.1 2018-12-05 [2]
## prodlim 2018.04.18 2018-04-18 [2]
## promises 1.0.1 2018-04-13 [2]
## ps 1.3.0 2018-12-21 [2]
## purrr * 0.2.5 2018-05-29 [2]
## R6 2.3.0 2018-10-04 [1]
## rcfss * 0.1.5 2019-01-24 [1]
## Rcpp 1.0.0 2018-11-07 [1]
## readr * 1.3.1 2018-12-21 [2]
## readxl 1.2.0 2018-12-19 [2]
## recipes * 0.1.4 2018-11-19 [1]
## remotes 2.0.2 2018-10-30 [1]
## reshape2 1.4.3 2017-12-11 [2]
## rlang 0.3.0.1 2018-10-25 [1]
## rmarkdown 1.11 2018-12-08 [2]
## rpart 4.1-13 2018-02-23 [1]
## rprojroot 1.3-2 2018-01-03 [2]
## rsample * 0.0.3 2018-11-20 [1]
## rsconnect 0.8.12 2018-12-05 [2]
## rstan 2.18.2 2018-11-07 [2]
## rstanarm 2.18.2 2018-11-10 [2]
## rstantools 1.5.1 2018-08-22 [2]
## rstudioapi 0.8 2018-10-02 [1]
## rvest 0.3.2 2016-06-17 [2]
## scales * 1.0.0 2018-08-09 [1]
## sessioninfo 1.1.1 2018-11-05 [1]
## shiny 1.2.0 2018-11-02 [2]
## shinyjs 1.0 2018-01-08 [2]
## shinystan 2.5.0 2018-05-01 [2]
## shinythemes 1.1.2 2018-11-06 [2]
## SnowballC 0.5.1 2014-08-09 [2]
## StanHeaders 2.18.0-1 2018-12-13 [2]
## stringi 1.2.4 2018-07-20 [2]
## stringr * 1.3.1 2018-05-10 [2]
## survival 2.43-3 2018-11-26 [2]
## testthat 2.0.1 2018-10-13 [2]
## threejs 0.3.1 2017-08-13 [2]
## tibble * 2.0.0 2019-01-04 [2]
## tidymodels * 0.0.2 2018-11-27 [1]
## tidyposterior 0.0.2 2018-11-15 [1]
## tidypredict 0.2.1 2018-12-20 [1]
## tidyr * 0.8.2 2018-10-28 [2]
## tidyselect 0.2.5 2018-10-11 [1]
## tidytext 0.2.0 2018-10-17 [1]
## tidyverse * 1.2.1 2017-11-14 [2]
## timeDate 3043.102 2018-02-21 [2]
## tokenizers 0.2.1 2018-03-29 [2]
## usethis 1.4.0 2018-08-14 [1]
## withr 2.1.2 2018-03-15 [2]
## xfun 0.4 2018-10-23 [1]
## xml2 1.2.0 2018-01-24 [2]
## xtable 1.8-3 2018-08-29 [2]
## xts 0.11-2 2018-11-05 [2]
## yaml 2.2.0 2018-07-25 [2]
## yardstick * 0.0.2 2018-11-05 [1]
## zoo 1.8-4 2018-09-19 [2]
## source
## (column wrapped in the console output) CRAN (R 3.5.0–3.5.2) for nearly
## all packages, plus Github (rstudio/blogdown@b2e1ed4) and one locally
## installed package
##
## [1] /Users/soltoffbc/Library/R/3.5/library
## [2] /Library/Frameworks/R.framework/Versions/3.5/Resources/library
Bergmeir, Christoph, and José M. Benítez. 2012. “On the Use of Cross-Validation for Time Series Predictor Evaluation.” Information Sciences 191: 192–213.
Breiman, Leo, and Philip Spector. 1992. “Submodel Selection and Evaluation in Regression: The X-Random Case.” International Statistical Review / Revue Internationale de Statistique: 291–319.
Friedman, Jerome, Trevor Hastie, and Robert Tibshirani. 2001. The Elements of Statistical Learning. Vol. 1. Springer Series in Statistics. New York: Springer. https://web.stanford.edu/~hastie/ElemStatLearn/.
James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. 2013. An Introduction to Statistical Learning. Vol. 112. New York: Springer. http://www-bcf.usc.edu/~gareth/ISL/.
Kohavi, Ron, and others. 1995. “A Study of Cross-Validation and Bootstrap for Accuracy Estimation and Model Selection.” In IJCAI, 14 (2): 1137–45. Montreal, Canada.
For educational purposes, here we will omit the test set. In a real-world situation, we would first partition out a test set of data.↩
The actual value you use is irrelevant. Just be sure to set it in the script, otherwise R will randomly pick one each time you start a new session.↩
The default family for `glm()` is `gaussian()`, or the Gaussian distribution. You probably know it by its other name, the Normal distribution.↩
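The equivalence noted in the footnote above is easy to verify: a `glm()` call with no `family` argument fits ordinary least squares, so it reproduces the coefficients of `lm()`. A quick sketch, using the built-in `mtcars` data (not part of these notes):

```r
# glm() defaults to family = gaussian(), so it fits the same model as lm()
lm_mod  <- lm(mpg ~ wt, data = mtcars)
glm_mod <- glm(mpg ~ wt, data = mtcars)

# identical point estimates
all.equal(coef(lm_mod), coef(glm_mod))
## [1] TRUE

# the family stored in the fitted object confirms the default
family(glm_mod)$family
## [1] "gaussian"
```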